{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Concise Implementation of Multilayer Perceptron\n",
    "\n",
    ":label:`sec_mlp_djl`\n",
    "\n",
    "\n",
    "As you might expect, by relying on the DJL library,\n",
    "we can implement MLPs even more concisely.\n",
    "Let's set up the relevant libraries first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.7.0-SNAPSHOT\n",
    "%maven ai.djl:model-zoo:0.7.0-SNAPSHOT\n",
    "%maven ai.djl:basicdataset:0.7.0-SNAPSHOT\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "%maven ai.djl.mxnet:mxnet-engine:0.7.0-SNAPSHOT\n",
    "%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%loadFromPOM\n",
    "<dependency>\n",
    "    <groupId>tech.tablesaw</groupId>\n",
    "    <artifactId>tablesaw-jsplot</artifactId>\n",
    "    <version>0.30.4</version>\n",
    "</dependency>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/plot-utils.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.nio.file.*;\n",
    "import ai.djl.Device;\n",
    "import ai.djl.*;\n",
    "import ai.djl.metric.*;\n",
    "import ai.djl.ndarray.*;\n",
    "import ai.djl.ndarray.types.*;\n",
    "import ai.djl.ndarray.index.*;\n",
    "import ai.djl.nn.*;\n",
    "import ai.djl.nn.core.*;\n",
    "import ai.djl.training.*;\n",
    "import ai.djl.training.initializer.*;\n",
    "import ai.djl.training.loss.*;\n",
    "import ai.djl.training.listener.*;\n",
    "import ai.djl.training.evaluator.*;\n",
    "import ai.djl.training.optimizer.*;\n",
    "import ai.djl.training.optimizer.learningrate.*;\n",
    "import ai.djl.training.dataset.*;\n",
    "import ai.djl.util.*;\n",
    "import java.util.Random;\n",
    "import java.util.stream.LongStream;\n",
    "import ai.djl.basicdataset.FashionMnist;\n",
    "import ai.djl.training.dataset.Dataset;\n",
    "import tech.tablesaw.api.*;\n",
    "import tech.tablesaw.plotly.api.*;\n",
    "import tech.tablesaw.plotly.components.*;\n",
    "import tech.tablesaw.plotly.Plot;\n",
    "import tech.tablesaw.plotly.components.Figure;\n",
    "import org.apache.commons.lang3.ArrayUtils;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Model\n",
    "\n",
    "Compared with our concise implementation\n",
    "of softmax regression\n",
    "(:numref:`sec_softmax_gluon`),\n",
    "the only difference is that we add\n",
    "*two* `Linear` (fully connected) layers\n",
    "(previously, we added *one*).\n",
    "The first is our hidden layer,\n",
    "which contains *256* hidden units\n",
    "and is followed by the ReLU activation function.\n",
    "The second is our output layer,\n",
    "which produces one score for each of the 10 Fashion-MNIST classes.\n",
    "Before these layers, `Blocks.batchFlattenBlock(784)` flattens each 28 x 28 image into a vector of length 784."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    }
   },
   "outputs": [],
   "source": [
    "SequentialBlock net = new SequentialBlock();\n",
    "net.add(Blocks.batchFlattenBlock(784));\n",
    "net.add(Linear.builder().setOutChannels(256).build());\n",
    "net.add(Activation::relu);\n",
    "net.add(Linear.builder().setOutChannels(10).build());\n",
    "net.setInitializer(new NormalInitializer());"
   ]
  },
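  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the architecture concrete, the following is a minimal plain-Java sketch (no DJL) of the computation that the `SequentialBlock` above composes for one already-flattened example: a linear layer, an elementwise ReLU, and a second linear layer. The class `TinyMlp`, its methods, and all weights are hypothetical and chosen only for illustration; the sizes are shrunk from 784, 256, and 10 to 4, 3, and 2, and this is not how DJL implements `Linear` internally.\n",
    "\n",
    "```java\n",
    "import java.util.Arrays;\n",
    "\n",
    "// Hypothetical illustration (plain Java, no DJL): the forward pass that\n",
    "// the SequentialBlock composes, for one already-flattened example.\n",
    "public class TinyMlp {\n",
    "\n",
    "    // Fully connected layer: out[j] = b[j] + sum over i of w[j][i] * x[i]\n",
    "    static double[] linear(double[][] w, double[] b, double[] x) {\n",
    "        double[] out = new double[w.length];\n",
    "        for (int j = 0; j < w.length; j++) {\n",
    "            double sum = b[j];\n",
    "            for (int i = 0; i < x.length; i++) {\n",
    "                sum += w[j][i] * x[i];\n",
    "            }\n",
    "            out[j] = sum;\n",
    "        }\n",
    "        return out;\n",
    "    }\n",
    "\n",
    "    // ReLU, applied elementwise: max(x, 0)\n",
    "    static double[] relu(double[] x) {\n",
    "        double[] out = new double[x.length];\n",
    "        for (int i = 0; i < x.length; i++) {\n",
    "            out[i] = Math.max(x[i], 0.0);\n",
    "        }\n",
    "        return out;\n",
    "    }\n",
    "\n",
    "    // Hidden layer (linear + ReLU), then output layer (linear only)\n",
    "    static double[] forward(double[][] w1, double[] b1,\n",
    "                            double[][] w2, double[] b2, double[] x) {\n",
    "        double[] hidden = relu(linear(w1, b1, x));\n",
    "        return linear(w2, b2, hidden);\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        // 4 inputs, 3 hidden units, 2 outputs (instead of 784, 256, 10)\n",
    "        double[] x = {1.0, -2.0, 0.5, 3.0};\n",
    "        double[][] w1 = {\n",
    "            {0.1, 0.2, -0.3, 0.4},\n",
    "            {-0.5, 0.1, 0.2, -0.1},\n",
    "            {0.3, -0.2, 0.1, 0.2}\n",
    "        };\n",
    "        double[] b1 = {0.0, 0.1, -0.1};\n",
    "        double[][] w2 = {\n",
    "            {1.0, -1.0, 0.5},\n",
    "            {-0.5, 0.5, 1.0}\n",
    "        };\n",
    "        double[] b2 = {0.0, 0.0};\n",
    "        // Prints two unnormalized scores (logits), one per output unit\n",
    "        System.out.println(Arrays.toString(forward(w1, b1, w2, b2, x)));\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "The output layer applies no activation, so it returns unnormalized scores (logits); in this notebook, `Loss.softmaxCrossEntropyLoss()` applies the softmax when computing the loss."
   ]
  },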
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that DJL, as usual, automatically\n",
    "infers the missing input dimension of each layer\n",
    "when `trainer.initialize` is called with the input shape `(1, 784)`.\n",
    "\n",
    "The training loop is *exactly* the same\n",
    "as when we implemented softmax regression.\n",
    "This modularity enables us to separate\n",
    "matters concerning the model architecture\n",
    "from orthogonal considerations such as data loading and optimization."
   ]
  },
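  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DJL infers the input size of each `Linear` layer from the output size of the layer before it. The minimal plain-Java sketch below does the same bookkeeping by hand for the widths 784, 256, and 10, and shows how many weights and biases each layer ends up with. The class `MlpShapes` and its methods are hypothetical helpers for illustration, not part of DJL; in the notebook, DJL performs the corresponding shape inference itself when `trainer.initialize` receives the input shape.\n",
    "\n",
    "```java\n",
    "// Hypothetical illustration (plain Java, no DJL): each Linear layer's input\n",
    "// size is the previous layer's output size, which fixes its parameter count.\n",
    "public class MlpShapes {\n",
    "\n",
    "    // A fully connected layer with the given number of inputs and outputs\n",
    "    // has inputs * outputs weights plus one bias per output.\n",
    "    static long linearParams(long inputs, long outputs) {\n",
    "        return inputs * outputs + outputs;\n",
    "    }\n",
    "\n",
    "    // Total parameters of a stack of Linear layers, given the layer widths\n",
    "    // (input width first, output width last).\n",
    "    static long totalParams(long[] widths) {\n",
    "        long total = 0;\n",
    "        for (int k = 0; k + 1 < widths.length; k++) {\n",
    "            // Input size of this layer = output size of the previous one\n",
    "            total += linearParams(widths[k], widths[k + 1]);\n",
    "        }\n",
    "        return total;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        // Widths used in this notebook: flattened image, hidden layer, classes\n",
    "        long[] widths = {784, 256, 10};\n",
    "        System.out.println(linearParams(784, 256)); // 200960\n",
    "        System.out.println(linearParams(256, 10));  // 2570\n",
    "        System.out.println(totalParams(widths));    // 203530\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "With these widths, the hidden layer holds about 98.7% of the 203,530 parameters, because it connects each of the 784 pixels to every one of the 256 hidden units."
   ]
  },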
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "int batchSize = 256;\n",
    "int numEpochs = 10;\n",
    "double[] trainLoss;\n",
    "double[] testAccuracy;\n",
    "double[] epochCount;\n",
    "double[] trainAccuracy;\n",
    "\n",
    "trainLoss = new double[numEpochs];\n",
    "trainAccuracy = new double[numEpochs];\n",
    "testAccuracy = new double[numEpochs];\n",
    "epochCount = new double[numEpochs];\n",
    "\n",
    "FashionMnist trainIter = FashionMnist.builder()\n",
    "                            .optUsage(Dataset.Usage.TRAIN)\n",
    "                            .setSampling(batchSize, true)\n",
    "                            .build();\n",
    "\n",
    "\n",
    "FashionMnist testIter = FashionMnist.builder()\n",
    "                            .optUsage(Dataset.Usage.TEST)\n",
    "                            .setSampling(batchSize, true)\n",
    "                            .build();\n",
    "\n",
    "trainIter.prepare();\n",
    "testIter.prepare();\n",
    "\n",
    "for(int i = 0; i < epochCount.length; i++) {\n",
    "    epochCount[i] = (i + 1);\n",
    "}\n",
    "\n",
    "Map<String, double[]> evaluatorMetrics = new HashMap<>();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "6"
    }
   },
   "outputs": [],
   "source": [
    "LearningRateTracker lrt = LearningRateTracker.fixedLearningRate(0.5f);\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "                .optOptimizer(sgd) // Optimizer (updates the parameters)\n",
    "                .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "                .addTrainingListeners(TrainingListener.Defaults.basic()); // Logging\n",
    "\n",
    "try (Model model = Model.newInstance(\"mlp\")) {\n",
    "    model.setBlock(net);\n",
    "\n",
    "    try (Trainer trainer = model.newTrainer(config)) {\n",
    "\n",
    "        trainer.initialize(new Shape(1, 784));\n",
    "        trainer.setMetrics(new Metrics());\n",
    "\n",
    "        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);\n",
    "        // Collect the per-epoch values of every evaluator (the loss and the accuracy)\n",
    "        Metrics metrics = trainer.getMetrics();\n",
    "\n",
    "        trainer.getEvaluators().stream()\n",
    "                .forEach(evaluator -> {\n",
    "                    evaluatorMetrics.put(\"train_epoch_\" + evaluator.getName(), metrics.getMetric(\"train_epoch_\" + evaluator.getName()).stream()\n",
    "                                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n",
    "                    evaluatorMetrics.put(\"validate_epoch_\" + evaluator.getName(), metrics.getMetric(\"validate_epoch_\" + evaluator.getName()).stream()\n",
    "                                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n",
    "                });\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainLoss = evaluatorMetrics.get(\"train_epoch_SoftmaxCrossEntropyLoss\");\n",
    "trainAccuracy = evaluatorMetrics.get(\"train_epoch_Accuracy\");\n",
    "testAccuracy = evaluatorMetrics.get(\"validate_epoch_Accuracy\");\n",
    "\n",
    "String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];\n",
    "\n",
    "// Labels follow the order of the values in the loss column: test acc, train acc, train loss\n",
    "Arrays.fill(lossLabel, 0, testAccuracy.length, \"test acc\");\n",
    "Arrays.fill(lossLabel, testAccuracy.length, testAccuracy.length + trainAccuracy.length, \"train acc\");\n",
    "Arrays.fill(lossLabel, testAccuracy.length + trainAccuracy.length, lossLabel.length, \"train loss\");\n",
    "\n",
    "Table data = Table.create(\"Data\").addColumns(\n",
    "            DoubleColumn.create(\"epochCount\", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),\n",
    "            DoubleColumn.create(\"loss\", ArrayUtils.addAll(testAccuracy, ArrayUtils.addAll(trainAccuracy, trainLoss))),\n",
    "            StringColumn.create(\"lossLabel\", lossLabel)\n",
    ");\n",
    "\n",
    "render(LinePlot.create(\"\", data, \"epochCount\", \"loss\", \"lossLabel\"),\"text/html\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "1. Try adding different numbers of hidden layers. What setting (keeping other parameters and hyperparameters constant) works best? \n",
    "1. Try out different activation functions. Which ones work best?\n",
    "1. Try different schemes for initializing the weights. What method works best?\n",
    "\n",
    "## [Discussions](https://discuss.mxnet.io/t/2340)\n",
    "\n",
    "![](../img/qr_mlp-gluon.svg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "11.0.5+10-LTS"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
